{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Alice and Bob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import torch\n",
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# python 3.7 is required\n",
    "assert sys.version_info[0] == 3 and sys.version_info[1] == 7, \"python 3.7 is required\"\n",
    "\n",
    "import crypten\n",
    "from crypten import mpc\n",
    "crypten.init()\n",
    "torch.set_num_threads(1)\n",
    "\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save Encrypted Data From Each Party"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "alice_data = torch.tensor([1, 2, 3.0])\n",
    "bob_data = torch.tensor([4, 5, 6.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alice is party 0\n",
    "ALICE = 0\n",
    "BOB = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[None, None]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@mpc.run_multiprocess(world_size=2)\n",
    "def save_all_data():\n",
    "    crypten.save(alice_data, \"/tmp/data/alice_data.pth\", src=ALICE)\n",
    "    crypten.save(bob_data, \"/tmp/data/bob_data.pth\", src=BOB)\n",
    "    \n",
    "save_all_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 16\r\n",
      "drwxr-xr-x  4 marksibrahim  wheel   128B Apr 17 18:58 \u001b[1m\u001b[36mMNIST\u001b[m\u001b[m\r\n",
      "-rw-r--r--  1 marksibrahim  wheel   351B Apr 17 19:55 alice_data.pth\r\n",
      "-rw-r--r--  1 marksibrahim  wheel   351B Apr 17 19:55 bob_data.pth\r\n"
     ]
    }
   ],
   "source": [
    "! ls -lh /tmp/data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Encrypted Data From Each Party"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'crypten.mpc.mpc.MPCTensor'>\n",
      "<class 'crypten.mpc.mpc.MPCTensor'>\n",
      "alice data is tensor([1., 2., 3.])\n",
      "alice data is tensor([1., 2., 3.])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[None, None]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@mpc.run_multiprocess(world_size=2)\n",
    "def load_data():\n",
    "    alice_data_enc = crypten.load(\"/tmp/data/alice_data.pth\", src=ALICE)\n",
    "    bob_data_enc = crypten.load(\"/tmp/data/bob_data.pth\", src=BOB)\n",
    "    \n",
    "    print(type(alice_data_enc))\n",
    "    print(f\"alice data is {alice_data_enc.get_plain_text()}\")\n",
    "\n",
    "load_data()\n",
    " "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Digits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Alice has Digits. Bob has Labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "digits = torchvision.datasets.MNIST(root='/tmp/data', \n",
    "                                           train=True, \n",
    "                                           transform=torchvision.transforms.ToTensor(),\n",
    "                                           download=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def take_samples(digits, n_samples=1000):\n",
    "    \"\"\"Returns images and labels based on sample size\"\"\"\n",
    "    images, labels = [], []\n",
    "\n",
    "    for i, digit in enumerate(digits):\n",
    "        if i == n_samples:\n",
    "            break\n",
    "        image, label = digit\n",
    "        images.append(image)\n",
    "        label_one_hot = torch.nn.functional.one_hot(torch.tensor(label), 10)\n",
    "        labels.append(label_one_hot)\n",
    "\n",
    "    images = torch.cat(images)\n",
    "    labels = torch.stack(labels)\n",
    "    return images, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "images, labels = take_samples(digits, n_samples=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 28, 28])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "images.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save Alice and Bob's Digits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[None, None]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@mpc.run_multiprocess(world_size=2)\n",
    "def save_digits():\n",
    "    crypten.save(images, \"/tmp/data/alice_images.pth\", src=ALICE)\n",
    "    crypten.save(labels, \"/tmp/data/bob_labels.pth\", src=BOB)\n",
    "      \n",
    "save_digits()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clean up tmp directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "! find /tmp/data -name \"*.pth\" -type f -delete"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Full Joint Training\n",
    "see `training_across_parties.py`\n",
    "\n",
    "For a full set of examples with scripts to run on separate AWS machines see https://github.com/facebookresearch/CrypTen/tree/master/examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "crypten-pycon",
   "language": "python",
   "name": "crypten-pycon"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
